//===----------------------------------------------------------------------===//
//
//                         BusTub
//
// nested_loop_join_executor.cpp
//
// Identification: src/execution/nested_loop_join_executor.cpp
//
// Copyright (c) 2015-2021, Carnegie Mellon University Database Group
//
//===----------------------------------------------------------------------===//

#include "execution/executors/nested_loop_join_executor.h"
#include "binder/table_ref/bound_join_ref.h"
#include "common/exception.h"
#include "type/value_factory.h"

namespace bustub {

/**
 * Build a nested-loop join over two child executors.
 * @param exec_ctx        execution context shared by the plan's executors
 * @param plan            join plan node (join type, predicate, output schema)
 * @param left_executor   outer (left) child; ownership is taken
 * @param right_executor  inner (right) child; ownership is taken, rescanned once per left tuple
 * @throws NotImplementedException for any join type other than INNER or LEFT
 */
NestedLoopJoinExecutor::NestedLoopJoinExecutor(ExecutorContext *exec_ctx, const NestedLoopJoinPlanNode *plan,
                                               std::unique_ptr<AbstractExecutor> &&left_executor,
                                               std::unique_ptr<AbstractExecutor> &&right_executor)
    : AbstractExecutor(exec_ctx),
      plan_(plan),
      left_executor_(std::move(left_executor)),
      right_executor_(std::move(right_executor)) {
  // 2023 Spring scope: only inner and left outer joins are implemented, so reject everything else up front.
  const auto join_type = plan->GetJoinType();
  const bool supported = join_type == JoinType::INNER || join_type == JoinType::LEFT;
  if (!supported) {
    throw bustub::NotImplementedException(fmt::format("join type {} not supported", join_type));
  }
}

void NestedLoopJoinExecutor::Init() {
  left_executor_->Init();  // 左右初始化. 先获取左边的left_tuple_
  right_executor_->Init();
  RID left_rid;
  left_return_ = left_executor_->Next(&left_tuple_, &left_rid);
}

/**
 * Produce the next joined tuple.
 * Dispatches on the plan's join type: LEFT uses the outer-join path; any other type (only INNER
 * can get past the constructor) uses the inner-join path.
 * @param[out] tuple  receives the joined tuple when true is returned
 * @param[out] rid    forwarded to the join-specific routine
 * @return true if a tuple was produced, false once the join is exhausted
 */
auto NestedLoopJoinExecutor::Next(Tuple *tuple, RID *rid) -> bool {
  const bool is_left_join = plan_->GetJoinType() == JoinType::LEFT;
  return is_left_join ? NextForLeftJoin(tuple, rid) : NextForInnerJoin(tuple, rid);
}

/**
 * Inner-join step: advance the (left, right) cursor pair until a pair satisfies the join
 * predicate, then emit the left columns followed by the right columns.
 * The right child is rescanned (re-Init) once per left tuple.
 * @param[out] tuple  receives the concatenated row when true is returned
 * @param[out] rid    not written
 * @return true if a matching pair was emitted, false once the left side is exhausted
 */
auto NestedLoopJoinExecutor::NextForInnerJoin(Tuple *tuple, RID *rid) -> bool {
  while (left_return_) {
    RID right_rid;
    if (!right_executor_->Next(&right_tuple_, &right_rid)) {
      // Right side exhausted for this left tuple: rewind it and step to the next left tuple.
      right_executor_->Init();
      RID left_rid;
      left_return_ = left_executor_->Next(&left_tuple_, &left_rid);
      continue;
    }

    const auto &left_schema = left_executor_->GetOutputSchema();
    const auto &right_schema = right_executor_->GetOutputSchema();
    const Value verdict = plan_->Predicate()->EvaluateJoin(&left_tuple_, left_schema, &right_tuple_, right_schema);
    // A NULL predicate result counts as "no match", just like false.
    if (verdict.IsNull() || !verdict.GetAs<bool>()) {
      continue;
    }

    std::vector<Value> joined;
    for (uint32_t col = 0; col < left_schema.GetColumnCount(); ++col) {
      joined.push_back(left_tuple_.GetValue(&left_schema, col));
    }
    for (uint32_t col = 0; col < right_schema.GetColumnCount(); ++col) {
      joined.push_back(right_tuple_.GetValue(&right_schema, col));
    }
    *tuple = Tuple(joined, &GetOutputSchema());
    return true;
  }
  // Left side exhausted: no more output.
  return false;
}

/**
 * Left-outer-join step.
 *
 * For the current left tuple, scans the right child and emits one row per right tuple that
 * satisfies the join predicate. When the right child is exhausted and the left tuple matched
 * nothing, emits a single row of the left columns padded with typed NULLs for the right columns.
 * Then the right child is rewound and the left side advances.
 *
 * Match tracking: `left_joined_` is cleared every time the left side advances, so it only ever
 * describes the current left tuple. Left tuples that do not come from a table scan may share an
 * identical (or invalid) RID, so a set keyed on RID that is never cleared would wrongly suppress
 * their NULL-padded rows.
 *
 * @param[out] tuple  receives the output row when true is returned
 * @param[out] rid    not written
 * @return true if a row was emitted, false once the left side is exhausted
 */
auto NestedLoopJoinExecutor::NextForLeftJoin(Tuple *tuple, RID *rid) -> bool {
  const auto &left_schema = left_executor_->GetOutputSchema();
  const auto &right_schema = right_executor_->GetOutputSchema();
  const auto left_col = left_schema.GetColumnCount();
  const auto right_col = right_schema.GetColumnCount();

  while (left_return_) {
    RID right_rid;
    if (!right_executor_->Next(&right_tuple_, &right_rid)) {
      // Right side exhausted for the current left tuple.
      const bool matched = left_joined_.count(left_tuple_.GetRid()) != 0;
      if (!matched) {
        // Build the NULL-padded row before advancing, while left_tuple_ still holds this tuple.
        std::vector<Value> values;
        values.reserve(left_col + right_col);
        for (uint32_t i = 0; i < left_col; ++i) {
          values.push_back(left_tuple_.GetValue(&left_schema, i));
        }
        for (uint32_t i = 0; i < right_col; ++i) {
          values.push_back(ValueFactory::GetNullValueByType(right_schema.GetColumn(i).GetType()));
        }
        *tuple = Tuple(values, &GetOutputSchema());
      }

      // Forget the finished left tuple's match state, rewind the right side, advance the left side.
      left_joined_.clear();
      right_executor_->Init();
      RID left_rid;
      left_return_ = left_executor_->Next(&left_tuple_, &left_rid);

      if (!matched) {
        return true;
      }
      continue;
    }

    const Value verdict = plan_->Predicate()->EvaluateJoin(&left_tuple_, left_schema, &right_tuple_, right_schema);
    // A NULL predicate result counts as "no match", just like false.
    if (verdict.IsNull() || !verdict.GetAs<bool>()) {
      continue;
    }

    // Remember that the current left tuple produced at least one joined row.
    left_joined_.insert(left_tuple_.GetRid());
    std::vector<Value> values;
    values.reserve(left_col + right_col);
    for (uint32_t i = 0; i < left_col; ++i) {
      values.push_back(left_tuple_.GetValue(&left_schema, i));
    }
    for (uint32_t i = 0; i < right_col; ++i) {
      values.push_back(right_tuple_.GetValue(&right_schema, i));
    }
    *tuple = Tuple(values, &GetOutputSchema());
    return true;
  }
  // Left side exhausted: no more output.
  return false;
}

}  // namespace bustub
